Python at Scale
The Serialisation Tax That Kills Your Speedup
This looks like an efficient parallel computation:
import multiprocessing
import time
def process_item(item: bytes) -> int:
    """Process a small binary item - 10 bytes each."""
    return sum(item)
items_small = [bytes([i % 256]) * 10 for i in range(100_000)] # 100k × 10 bytes
# Sequential
start = time.perf_counter()
results_seq = [process_item(item) for item in items_small]
seq_time = time.perf_counter() - start
# Parallel - 8 workers
start = time.perf_counter()
with multiprocessing.Pool(8) as pool:
    results_par = pool.map(process_item, items_small)
par_time = time.perf_counter() - start
print(f"Sequential: {seq_time:.3f}s")
print(f"Parallel: {par_time:.3f}s")
print(f"Ratio: {par_time / seq_time:.1f}x {'SLOWER' if par_time > seq_time else 'faster'}")
Sequential: 0.008s
Parallel: 0.743s
Ratio: 92.8x SLOWER
The parallel version is 93x slower. Eight workers and negative speedup.
The cause: inter-process communication. Each item must be serialised (pickled) by the parent process, sent via a pipe to a worker process, deserialised (unpickled), processed, and the result serialised and sent back. For 10-byte items, the serialisation overhead dwarfs the computation cost.
Now the same pattern with large items:
items_large = [bytes([i % 256]) * 100_000 for i in range(1_000)] # 1k × 100KB
start = time.perf_counter()
results_seq = [process_item(item) for item in items_large]
seq_time = time.perf_counter() - start
start = time.perf_counter()
with multiprocessing.Pool(8) as pool:
    results_par = pool.map(process_item, items_large)
par_time = time.perf_counter() - start
print(f"Sequential: {seq_time:.3f}s")
print(f"Parallel: {par_time:.3f}s")
print(f"Speedup: {seq_time / par_time:.1f}x")
Sequential: 1.823s
Parallel: 0.253s
Speedup: 7.2x ← near-linear on 8 cores
The crossover point on this machine is approximately 1KB per item. Below that threshold, serialisation overhead exceeds the parallelism benefit. Above it, you get near-linear scaling.
| Item Size | Sequential | Parallel (8 workers) | Result |
|---|---|---|---|
| 10 bytes | 0.008s | 0.743s | 93x SLOWER |
| 1 KB | 0.21s | 0.23s | ~Breakeven |
| 100 KB | 1.82s | 0.25s | 7.2x faster |
| 10 MB | 18.4s | 2.6s | 7.1x faster |
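One way to move small items above the crossover is to batch them, so that a single pickle round trip covers many items. The sketch below is a minimal illustration of that idea; `chunked`, `process_chunk`, `parallel_sums` and the batch size of 12,500 are hypothetical names and values chosen for this example, not part of the benchmark above.

```python
from concurrent.futures import ProcessPoolExecutor
from itertools import islice
from typing import Iterable, Iterator


def chunked(items: Iterable[bytes], size: int) -> Iterator[list[bytes]]:
    """Yield successive lists of at most `size` items."""
    it = iter(items)
    while chunk := list(islice(it, size)):
        yield chunk


def process_chunk(chunk: list[bytes]) -> list[int]:
    """Process a whole batch in one task: one IPC round trip for many items."""
    return [sum(item) for item in chunk]


def parallel_sums(items: list[bytes], n_workers: int = 8,
                  batch_size: int = 12_500) -> list[int]:
    """Hypothetical helper: submit one task per batch instead of one per item."""
    results: list[int] = []
    with ProcessPoolExecutor(max_workers=n_workers) as executor:
        # executor.map preserves input order, so batches come back in sequence
        for batch_result in executor.map(process_chunk, chunked(items, batch_size)):
            results.extend(batch_result)
    return results
```

With 100,000 ten-byte items this produces 8 tasks of roughly 125 KB each. Call `parallel_sums(items_small)` from under an `if __name__ == "__main__":` guard, which platforms that spawn worker processes require.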
This lesson teaches you to pick the right tool for each scale challenge and avoid the serialisation trap.
What You Will Learn
- Understand why the GIL forces multiprocessing for CPU-bound Python parallelism
- Use multiprocessing.Pool and ProcessPoolExecutor with correct chunking
- Share large arrays across processes without copying using SharedMemory
- Use Ray for distributed actor-based computation
- Use Dask for lazy computation graphs over datasets larger than RAM
- Use Celery for background task queues with retry logic
- Choose the right tool for each scale problem
Prerequisites
| Requirement | Level Needed |
|---|---|
| Python functions and modules | Comfortable |
| NumPy array basics | Familiar |
| Basic OS/process concepts | Familiar |
| asyncio basics (for Celery) | Helpful |
Section 1: The GIL and True Parallelism
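CPython's GIL (Global Interpreter Lock) is a mutex that a thread must hold to execute Python bytecode. In a standard CPython build (the experimental free-threaded builds introduced in 3.13 are the exception), only one thread executes Python bytecode at a time, regardless of core count.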
import threading
import time
import multiprocessing
def cpu_work(n: int) -> int:
    """Pure Python CPU computation."""
    return sum(i * i for i in range(n))
N = 5_000_000
# Sequential
start = time.perf_counter()
cpu_work(N)
cpu_work(N)
seq_time = time.perf_counter() - start
# Threading - two threads, one CPU (GIL)
def run_threaded():
    t1 = threading.Thread(target=cpu_work, args=(N,))
    t2 = threading.Thread(target=cpu_work, args=(N,))
    t1.start(); t2.start()
    t1.join(); t2.join()
start = time.perf_counter()
run_threaded()
thread_time = time.perf_counter() - start
# Multiprocessing - two processes, two CPUs
def run_multiprocessing():
    with multiprocessing.Pool(2) as pool:
        pool.map(cpu_work, [N, N])
start = time.perf_counter()
run_multiprocessing()
mp_time = time.perf_counter() - start
print(f"Sequential: {seq_time:.2f}s")
print(f"Threading (2): {thread_time:.2f}s ({seq_time/thread_time:.1f}x)")
print(f"Multiprocessing: {mp_time:.2f}s ({seq_time/mp_time:.1f}x)")
Sequential: 2.41s
Threading (2): 2.63s (0.92x) - SLOWER due to GIL contention
Multiprocessing: 1.28s (1.88x) - actual parallel execution
Threading makes CPU-bound work slower - the two threads compete for the GIL, paying thread-switching overhead with no parallelism benefit.
When Threading DOES Help
The GIL is released during I/O operations and during C extension calls that explicitly release it (NumPy does). Threading is the right tool for:
- Blocking I/O (network, disk) - GIL is released during the syscall
- NumPy-heavy computation - GIL released for many NumPy ops
- asyncio + blocking code (via ThreadPoolExecutor)
import time
from concurrent.futures import ThreadPoolExecutor

import requests  # synchronous HTTP (pip install requests)

def fetch_url(url: str) -> int:
    response = requests.get(url)  # GIL released during socket I/O
    return len(response.content)

urls = [f"https://httpbin.org/bytes/{i*100}" for i in range(1, 11)]

# Sequential
start = time.perf_counter()
results = [fetch_url(u) for u in urls]
print(f"Sequential: {time.perf_counter() - start:.2f}s")

# Threaded - I/O releases GIL, real parallelism
start = time.perf_counter()
with ThreadPoolExecutor(max_workers=10) as ex:
    results = list(ex.map(fetch_url, urls))
print(f"Threaded: {time.perf_counter() - start:.2f}s")
Sequential: 4.32s
Threaded: 0.61s (7x speedup - real I/O parallelism)
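The third case in the list above, blocking code called from asyncio, has no example in this lesson, so here is a minimal sketch. It assumes Python 3.9+ for `asyncio.to_thread`, and `blocking_fetch` is a hypothetical stand-in (a `time.sleep`) for a real blocking call such as `requests.get`.

```python
import asyncio
import time


def blocking_fetch(delay: float) -> float:
    """Stand-in for a blocking call; time.sleep releases the GIL like socket I/O does."""
    time.sleep(delay)
    return delay


async def fetch_all(delays: list[float]) -> list[float]:
    """Run each blocking call in the default thread pool without blocking the event loop."""
    tasks = [asyncio.to_thread(blocking_fetch, d) for d in delays]
    return await asyncio.gather(*tasks)  # results come back in input order


if __name__ == "__main__":
    start = time.perf_counter()
    results = asyncio.run(fetch_all([0.2] * 5))
    print(f"{len(results)} calls in {time.perf_counter() - start:.2f}s")
```

`asyncio.to_thread` hands the call to the event loop's default thread pool, so the five waits overlap instead of running back to back.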
Section 2: multiprocessing.Pool
Pool creates a fixed number of worker processes and distributes tasks across them. It is the standard tool for CPU-bound parallelism on a single machine.
Core API
import multiprocessing
import time
import os
from pathlib import Path
from PIL import Image # pip install Pillow
import numpy as np
def resize_image(args: tuple[str, str, tuple[int, int]]) -> str:
    """
    Resize a single image. Returns the output path.
    Arguments packed in a tuple for Pool.map compatibility.
    """
    input_path, output_path, target_size = args
    try:
        with Image.open(input_path) as img:
            resized = img.resize(target_size, Image.LANCZOS)
            resized.save(output_path, quality=85, optimize=True)
        return output_path
    except Exception as e:
        return f"ERROR: {e}"

def process_image_batch(
    input_dir: str,
    output_dir: str,
    target_size: tuple[int, int] = (800, 600),
    n_workers: int | None = None,  # None = os.cpu_count()
) -> list[str]:
    """Parallel image resizing using a process pool."""
    input_files = list(Path(input_dir).glob("*.jpg"))
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    # Build task list
    tasks = [
        (str(f), str(Path(output_dir) / f.name), target_size)
        for f in input_files
    ]
    n_workers = n_workers or os.cpu_count()
    with multiprocessing.Pool(processes=n_workers) as pool:
        # map: blocks until all results are ready
        results = pool.map(resize_image, tasks)
    return results
map vs imap vs imap_unordered
| Method | Returns | Ordering | Memory | Best For |
|---|---|---|---|---|
| pool.map() | List | Preserved | All in RAM | Small-medium result sets |
| pool.starmap() | List | Preserved | All in RAM | Multiple arguments per task |
| pool.imap() | Iterator (lazy) | Preserved | O(chunk) | Large result sets |
| pool.imap_unordered() | Iterator | Not preserved | O(chunk) | Large sets, order unimportant |
| pool.apply_async() | AsyncResult | N/A | One at a time | Fire-and-forget tasks |
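The table lists `starmap()` and `apply_async()`, which the other examples in this section do not show. The sketch below illustrates both; `power` and `run_examples` are hypothetical names used only for this example.

```python
import multiprocessing


def power(base: int, exponent: int) -> int:
    """Two-argument task, so it needs starmap() rather than map()."""
    return base ** exponent


def run_examples() -> tuple[list[int], int]:
    with multiprocessing.Pool(2) as pool:
        # starmap unpacks each tuple into positional arguments
        squares = pool.starmap(power, [(2, 2), (3, 2), (4, 2)])
        # apply_async returns an AsyncResult at once; get() blocks for the value
        pending = pool.apply_async(power, (2, 10))
        single = pending.get(timeout=10)
    return squares, single
```

Call `run_examples()` from under an `if __name__ == "__main__":` guard. Unlike `map()`, `apply_async()` submits a single call, so it suits a handful of heterogeneous tasks better than a large homogeneous batch.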
import multiprocessing
import time
def analyse_document(doc_id: int) -> dict:
    """Simulated document analysis - variable duration."""
    import random
    time.sleep(random.uniform(0.01, 0.5))  # variable processing time
    return {"doc_id": doc_id, "word_count": random.randint(100, 5000)}

def process_documents_lazy(doc_ids: list[int]) -> None:
    """
    imap_unordered: results stream in as they complete.
    Faster first result and lower peak memory than pool.map().
    """
    with multiprocessing.Pool() as pool:
        for result in pool.imap_unordered(analyse_document, doc_ids, chunksize=10):
            # Process each result immediately rather than waiting for all
            save_to_database(result)  # placeholder sink, assumed to be defined elsewhere
chunksize Tuning
# chunksize controls how many items are sent to each worker per batch
# Too small: many round-trips to workers (high IPC overhead)
# Too large: poor load balancing (one slow item blocks a worker's entire chunk)
# Rule of thumb for chunksize:
# chunksize = max(1, len(items) // (4 * n_workers))
items = [bytes([i % 256]) * 10 for i in range(100_000)]  # same items as the opening benchmark
n_workers = 8
chunksize = max(1, len(items) // (4 * n_workers))  # 3125
with multiprocessing.Pool(n_workers) as pool:
    results = pool.map(process_item, items, chunksize=chunksize)
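The rule of thumb can be wrapped in a small helper so the same calculation is reused across call sites. This is only a sketch: `suggest_chunksize` and its `chunks_per_worker` parameter are hypothetical names, and the factor of 4 is the heuristic from the comments above, not a measured optimum.

```python
def suggest_chunksize(n_items: int, n_workers: int, chunks_per_worker: int = 4) -> int:
    """Rule-of-thumb chunksize: aim for about `chunks_per_worker` chunks per worker."""
    if n_workers < 1 or chunks_per_worker < 1:
        raise ValueError("n_workers and chunks_per_worker must be positive")
    return max(1, n_items // (chunks_per_worker * n_workers))


for n_items in (10, 1_000, 100_000):
    print(n_items, suggest_chunksize(n_items, n_workers=8))
# 10 1
# 1000 31
# 100000 3125
```

If task durations vary widely, a smaller chunksize (more chunks per worker) usually balances load better at the cost of more round trips.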
Section 3: ProcessPoolExecutor - The Modern API
concurrent.futures.ProcessPoolExecutor provides a cleaner interface than multiprocessing.Pool. It is preferred for new code.
from concurrent.futures import ProcessPoolExecutor, as_completed
import time
def analyse_shard(shard_path: str) -> dict:
    """Analyse a data shard - may take variable time."""
    import json
    with open(shard_path) as f:
        data = json.load(f)
    return {
        "shard": shard_path,
        "records": len(data),
        "total_value": sum(r.get("value", 0) for r in data),
    }

def analyse_all_shards(shard_paths: list[str], max_workers: int = 8) -> list[dict]:
    """
    Analyse shards in parallel.
    as_completed() yields results in completion order - not submission order.
    This means you can start processing results before all tasks finish.
    """
    results = []
    failed = []
    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        # Submit all tasks, get future handles
        future_to_shard = {
            executor.submit(analyse_shard, path): path
            for path in shard_paths
        }
        # Process results as they complete (not in submission order)
        for future in as_completed(future_to_shard):
            shard = future_to_shard[future]
            try:
                # The future is already finished here, so result() returns at once
                # and a timeout on this call would never fire. To bound the wait,
                # pass timeout= to as_completed() instead.
                result = future.result()
                results.append(result)
                print(f"Completed: {shard} - {result['records']} records")
            except Exception as e:
                print(f"Failed: {shard} - {e}")
                failed.append(shard)
    if failed:
        print(f"{len(failed)} shard(s) failed: {failed}")
    return results
# With a deadline on the entire batch
from concurrent.futures import wait, FIRST_EXCEPTION

def analyse_with_deadline(shard_paths: list[str], deadline_seconds: float) -> list[dict]:
    """Return results finished before the deadline; cancel tasks that have not started."""
    with ProcessPoolExecutor() as executor:
        futures = [executor.submit(analyse_shard, p) for p in shard_paths]
        done, not_done = wait(futures, timeout=deadline_seconds,
                              return_when=FIRST_EXCEPTION)
        # cancel() only stops tasks still waiting in the queue. Tasks that are
        # already running continue, and leaving the with block waits for them.
        for future in not_done:
            future.cancel()
    return [f.result() for f in done if not f.exception()]
Section 4: Shared Memory
Inter-process communication is the primary overhead in multiprocessing. For large arrays, the most efficient solution is shared memory: all worker processes read from the same physical memory pages - no serialisation, no copying.
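Before the NumPy version, the attach-by-name mechanism can be seen with the standard library alone. The sketch below runs in a single process for simplicity (a real worker would attach from another process), and `payload`, `owner` and `view` are illustrative names.

```python
from multiprocessing import shared_memory

payload = b"shared bytes"

# Create a named segment and write into it (the creator owns it).
owner = shared_memory.SharedMemory(create=True, size=len(payload))
try:
    owner.buf[:len(payload)] = payload

    # Attach by name, as a worker would; only the short name has to be sent.
    view = shared_memory.SharedMemory(name=owner.name)
    received = bytes(view.buf[:len(payload)])  # slice: the segment may be page-rounded
    view.close()    # detach; the segment still exists
finally:
    owner.close()
    owner.unlink()  # the owner removes the segment

print(received)
```

The segment can be larger than requested because the OS may round it up to a page boundary, which is why the read slices to `len(payload)` rather than taking the whole buffer.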
import multiprocessing
import multiprocessing.shared_memory
import numpy as np
import os
def worker_with_shared_memory(
    shm_name: str,
    shape: tuple,
    dtype: np.dtype,
    start_idx: int,
    end_idx: int,
) -> None:
    """
    Worker function that operates on a shared memory array.
    Receives only the metadata (name, shape, dtype) - not the data itself.
    """
    # Attach to the existing shared memory segment
    shm = multiprocessing.shared_memory.SharedMemory(name=shm_name)
    arr = np.ndarray(shape, dtype=dtype, buffer=shm.buf)
    # Operate on the slice this worker is responsible for
    # Multiple workers can safely read simultaneously
    # Multiple workers writing to NON-OVERLAPPING regions is also safe
    arr[start_idx:end_idx] **= 2  # vectorised in-place modification in shared memory
    del arr      # drop the array view before detaching
    shm.close()  # detach (do NOT unlink - parent owns the memory)

def parallel_square_inplace(data: np.ndarray, n_workers: int = 4) -> np.ndarray:
    """
    Square each element of a large array in parallel, using shared memory.
    No data copying between processes.
    """
    # Create shared memory and copy data into it
    shm = multiprocessing.shared_memory.SharedMemory(
        create=True,
        size=data.nbytes,
    )
    shared_arr = np.ndarray(data.shape, dtype=data.dtype, buffer=shm.buf)
    np.copyto(shared_arr, data)  # one copy - data → shared memory
    n = len(data)
    chunk_size = (n + n_workers - 1) // n_workers
    processes = []
    for i in range(n_workers):
        start = i * chunk_size
        end = min(start + chunk_size, n)
        p = multiprocessing.Process(
            target=worker_with_shared_memory,
            args=(shm.name, data.shape, data.dtype, start, end),
        )
        processes.append(p)
        p.start()
    for p in processes:
        p.join()
    result = shared_arr.copy()  # copy result out before cleanup
    shm.close()
    shm.unlink()  # parent deletes the shared memory
    return result
# Benchmark: shared memory vs pool.map (which pickles the data)
import time
large_array = np.random.rand(10_000_000)
start = time.perf_counter()
result_sm = parallel_square_inplace(large_array.copy())
print(f"Shared memory: {time.perf_counter() - start:.3f}s")
def square_chunk(chunk: np.ndarray) -> np.ndarray:
    # Must be a module-level function: Pool pickles the callable, and a lambda cannot be pickled
    return chunk ** 2

start = time.perf_counter()
with multiprocessing.Pool(4) as pool:
    chunks = np.array_split(large_array, 4)
    results = pool.map(square_chunk, chunks)
    result_mp = np.concatenate(results)
print(f"pool.map (with pickle): {time.perf_counter() - start:.3f}s")
Shared memory: 0.089s
pool.map (with pickle): 0.312s
Shared memory is 3.5x faster for large arrays because it eliminates pickling. The tradeoff: shared memory requires careful coordination - concurrent writes to overlapping regions cause data races.
Thread Safety with Shared Memory
import multiprocessing
import multiprocessing.shared_memory
from multiprocessing import Lock

import numpy as np

# For operations that require atomicity across workers:
lock = Lock()  # create in the parent and pass to every worker process

def thread_safe_counter_update(shm_name: str, index: int, lock: Lock) -> None:
    """Update a shared counter with a lock."""
    shm = multiprocessing.shared_memory.SharedMemory(name=shm_name)
    counters = np.ndarray((100,), dtype=np.int64, buffer=shm.buf)
    with lock:  # only one process at a time
        counters[index] += 1
    del counters  # drop the array view before detaching
    shm.close()
For embarrassingly parallel operations (each worker writes to a disjoint region of the shared array), no lock is needed.
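A lock-free layout depends on the regions being truly disjoint. A minimal sketch of computing such regions follows; `split_ranges` is a hypothetical helper that uses the same ceiling-division split as `parallel_square_inplace` above.

```python
def split_ranges(n_items: int, n_workers: int) -> list[tuple[int, int]]:
    """Return half-open [start, end) ranges that cover 0..n_items without overlap."""
    chunk_size = (n_items + n_workers - 1) // n_workers  # ceiling division
    ranges = []
    for i in range(n_workers):
        start = min(i * chunk_size, n_items)
        end = min(start + chunk_size, n_items)
        if start < end:  # skip empty ranges when there are more workers than items
            ranges.append((start, end))
    return ranges


print(split_ranges(10, 4))  # [(0, 3), (3, 6), (6, 9), (9, 10)]
```

Because each range ends exactly where the next begins, no two workers ever write the same element.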
Section 5: Ray - Distributed Python at Scale
Ray is a framework for distributed Python. It handles process management, serialisation, scheduling, fault tolerance, and cluster management. It scales from a single machine to thousands of nodes with the same API.
pip install ray
Basic Ray Functions
import ray
import time
import numpy as np
ray.init() # starts local Ray cluster (or connects to an existing one)
@ray.remote
def analyse_document(doc: dict) -> dict:
    """Remote function - runs in a Ray worker process."""
    words = doc['content'].split()
    return {
        "doc_id": doc['id'],
        "word_count": len(words),
        "unique_words": len(set(words)),
        "avg_word_length": sum(len(w) for w in words) / len(words) if words else 0,
    }

documents = [{"id": i, "content": f"This is document {i} " * 100} for i in range(1000)]

# Parallel: .remote() returns an ObjectRef (a future) immediately
start = time.perf_counter()
futures = [analyse_document.remote(doc) for doc in documents]
results = ray.get(futures)  # wait for all
print(f"Ray parallel: {time.perf_counter() - start:.2f}s")
ray.put() - Avoid Re-Serialising Large Objects
import ray
import numpy as np
# BAD: large array is serialised for every remote call
large_matrix = np.random.rand(10_000, 10_000)

@ray.remote
def process_with_matrix_bad(row_idx: int, matrix: np.ndarray) -> np.ndarray:
    return matrix[row_idx] ** 2  # matrix is passed by value and re-stored for every call

# GOOD: put the array in the Ray object store once, pass by reference
matrix_ref = ray.put(large_matrix)  # one serialisation, stored in object store

@ray.remote
def process_with_matrix_good(row_idx: int, matrix: np.ndarray) -> np.ndarray:
    # Ray resolves an ObjectRef passed as a top-level argument before the task runs,
    # so `matrix` is already the array here (zero-copy read on the same machine).
    # Calling ray.get() on it inside the task would fail.
    return matrix[row_idx] ** 2

futures = [process_with_matrix_good.remote(i, matrix_ref) for i in range(100)]
results = ray.get(futures)
Ray Actor Model - Stateful Distributed Workers
import ray
from typing import Optional
@ray.remote
class ModelServer:
    """
    Ray actor - a stateful process that can hold a loaded ML model.
    Multiple actors can serve requests in parallel without reloading the model.
    """
    def __init__(self, model_path: str):
        import torch
        self.model = torch.load(model_path)
        self.model.eval()
        self.request_count = 0

    def predict(self, features: list[float]) -> float:
        import torch
        self.request_count += 1
        with torch.no_grad():
            tensor = torch.tensor(features, dtype=torch.float32).unsqueeze(0)
            return float(self.model(tensor).item())

    def get_stats(self) -> dict:
        return {"requests_served": self.request_count}

# Create 4 model server replicas
servers = [ModelServer.remote("model.pt") for _ in range(4)]

# Route requests round-robin
import itertools
server_cycle = itertools.cycle(servers)

async def handle_inference_request(features: list[float]) -> float:
    server = next(server_cycle)
    return await server.predict.remote(features)
Distributed Hyperparameter Search with Ray
import ray
import numpy as np
from dataclasses import dataclass
from typing import Optional
@dataclass
class HyperParams:
    learning_rate: float
    batch_size: int
    dropout: float
    hidden_dim: int

@ray.remote
def evaluate_hyperparams(params: HyperParams, dataset: np.ndarray) -> dict:
    """
    Train and evaluate a model with given hyperparameters.
    Runs in a separate process - true parallelism.
    """
    # `dataset` arrives as the array itself: Ray resolves an ObjectRef passed
    # as a top-level argument, so no ray.get() is needed inside the task.
    # Simulate training
    import time
    import random
    time.sleep(2.0)  # training time
    # Simulate evaluation metric (replace with actual training)
    val_accuracy = (
        0.9
        - 0.1 * abs(params.learning_rate - 0.001) / 0.001
        + random.gauss(0, 0.02)
    )
    return {
        "params": params,
        "val_accuracy": val_accuracy,
        "training_time": 2.0,
    }

def distributed_grid_search(
    dataset: np.ndarray,
    hyperparameter_grid: list[HyperParams],
    n_concurrent: int = 8,  # not used below: Ray schedules as many tasks as resources allow
) -> HyperParams:
    """
    Evaluate all hyperparameter combinations in parallel using Ray.
    """
    dataset_ref = ray.put(dataset)  # put dataset in object store once
    futures = [
        evaluate_hyperparams.remote(params, dataset_ref)
        for params in hyperparameter_grid
    ]
    results = ray.get(futures)
    best = max(results, key=lambda r: r['val_accuracy'])
    print(f"Best accuracy: {best['val_accuracy']:.4f}")
    print(f"Best params: {best['params']}")
    return best['params']
Section 6: Dask - Lazy Computation Graphs
Dask provides large-scale array and dataframe operations that look like NumPy/pandas but execute lazily - building a computation graph first, then executing it with dask.compute(). This allows processing datasets larger than RAM.
pip install dask[complete]
Dask Arrays - NumPy That Doesn't Fit in RAM
import dask.array as da
import numpy as np
import time
# Create a 50GB virtual array (no actual memory allocated yet)
# Stored in 500MB chunks
large_array = da.random.random(
    size=(100_000, 62_500),  # 100k × 62.5k float64 = 50 GB
    chunks=(1_000, 62_500),  # each chunk = 1k × 62.5k × 8 bytes = 500 MB
)
print(f"Array shape: {large_array.shape}")
print(f"Array dtype: {large_array.dtype}")
print(f"Chunk shape: {large_array.chunks[0][0]} × {large_array.chunks[1][0]}")
print(f"Virtual size: {large_array.nbytes / 1024**3:.0f} GB")
# Operations build a computation graph - no computation yet
mean_per_row = large_array.mean(axis=1) # mean of each row
std_per_row = large_array.std(axis=1)
zscore_col_0 = (large_array[:, 0] - mean_per_row) / std_per_row
# .compute() triggers actual execution
# Dask processes one chunk at a time - peak memory ≈ chunk_size × n_workers
result = zscore_col_0.compute()
print(f"Result shape: {result.shape}")
print(f"Result dtype: {result.dtype}")
Dask DataFrames - pandas Beyond RAM
import dask.dataframe as dd
# Read a 50GB CSV without loading it into RAM
# Dask reads it in chunks, processing one partition at a time
df = dd.read_csv(
"data/transactions_*.csv", # glob pattern matches multiple files
dtype={
"amount": "float64",
"user_id": "int64",
"timestamp": "str",
},
blocksize="64MB", # each partition is ~64MB
)
print(f"Partitions: {df.npartitions}")
# Operations are lazy - build the computation graph
result = (
df[df["amount"] > 100] # filter rows
.groupby("user_id")["amount"] # group by user
.agg(["sum", "count", "mean"]) # aggregate
.reset_index()
.rename(columns={
"sum": "total_spend",
"count": "num_transactions",
"mean": "avg_transaction",
})
)
# .compute() triggers execution
# Dask reads each partition, applies the pipeline, merges results
aggregated = result.compute()
print(f"Unique users: {len(aggregated)}")
print(aggregated.head())
Visualising the Task Graph
import dask
import dask.array as da
x = da.random.random((10_000, 10_000), chunks=(2_500, 2_500))
y = da.dot(x, x.T)
result = y.mean()
# Render the task graph as an image
dask.visualize(result, filename='task_graph.png')
Dask Bag - For Unstructured Data
import dask.bag as db
import json
def parse_log_line(line: str) -> dict | None:
    """Parse a single log line - returns None if malformed."""
    try:
        return json.loads(line)
    except json.JSONDecodeError:
        return None
# Read a directory of log files
bag = db.read_text("logs/*.jsonl") # lazy - no data read yet
# Build pipeline
error_bag = (
bag
.map(parse_log_line)
.filter(lambda x: x is not None)
.filter(lambda x: x.get("level") == "ERROR")
.pluck("message")
)
# Execute
errors = error_bag.compute()
print(f"Total errors: {len(errors)}")
Dask Distributed Scheduler
For multi-machine clusters:
# Machine 1 (scheduler)
dask scheduler --host 0.0.0.0 --port 8786
# Machines 2-N (workers)
dask worker tcp://scheduler-host:8786 --nthreads 4 --memory-limit 16GB
from dask.distributed import Client
client = Client("tcp://scheduler-host:8786")
print(client)
# All dask computations now run on the cluster
result = large_dask_array.mean().compute()
Section 7: Celery - Task Queues for Background Work
Celery is a distributed task queue. It decouples slow operations (email sending, PDF generation, ML inference) from request handlers by running them asynchronously in worker processes.
pip install celery redis
# Redis as the message broker (RabbitMQ is another option)
Basic Setup
# celery_app.py
from celery import Celery
import os
app = Celery(
"myservice",
broker=os.environ.get("CELERY_BROKER_URL", "redis://localhost:6379/0"),
backend=os.environ.get("CELERY_RESULT_BACKEND", "redis://localhost:6379/1"),
)
app.conf.update(
task_serializer="json",
result_serializer="json",
accept_content=["json"],
task_track_started=True,
task_acks_late=True, # re-queue on worker crash
worker_prefetch_multiplier=1, # one task per worker at a time (fair scheduling)
task_routes={
"myservice.tasks.generate_report": {"queue": "heavy"},
"myservice.tasks.send_notification_email": {"queue": "light"},  # key must match the registered task name
},
)
Defining Tasks
# tasks.py
from celery_app import app
from celery.utils.log import get_task_logger
import time
logger = get_task_logger(__name__)
# fetch_report_data, render_report, notify_user, smtp_client and the SMTP error
# classes are application code assumed to be defined elsewhere.
@app.task(
    bind=True,
    max_retries=3,
    default_retry_delay=60,  # wait 60s before retry
    soft_time_limit=300,     # SoftTimeLimitExceeded is raised in the task after 5 minutes
    time_limit=360,          # the worker process is killed after 6 minutes
    queue="heavy",
)
def generate_report(self, user_id: int, report_type: str, params: dict) -> str:
    """
    Long-running report generation.
    Retries automatically on failure.
    """
    try:
        logger.info("Generating %s report for user %d", report_type, user_id)
        # Update task progress
        self.update_state(state="PROGRESS", meta={"current": 0, "total": 100})
        data = fetch_report_data(user_id, params)
        self.update_state(state="PROGRESS", meta={"current": 50, "total": 100})
        output_path = render_report(data, report_type)
        self.update_state(state="PROGRESS", meta={"current": 90, "total": 100})
        notify_user(user_id, output_path)
        return output_path
    except Exception as exc:
        logger.exception("Report generation failed for user %d", user_id)
        # Retry with exponential backoff: 60s, 120s, 240s
        raise self.retry(exc=exc, countdown=2 ** self.request.retries * 60)

@app.task(
    bind=True,
    max_retries=5,
    default_retry_delay=30,
    queue="light",
)
def send_notification_email(self, email: str, subject: str, body: str) -> None:
    """Email task with retry."""
    try:
        smtp_client.send(to=email, subject=subject, body=body)
    except SMTPTemporaryError as exc:
        raise self.retry(exc=exc)
    except SMTPPermanentError:
        logger.error("Permanent email failure for %s - not retrying", email)
        # Do not retry permanent failures (bad address, blocked, etc.)
Dispatching Tasks
# In a FastAPI handler:
from fastapi import FastAPI
from tasks import generate_report, send_notification_email
app = FastAPI()
@app.post("/api/reports")
async def request_report(user_id: int, report_type: str):
    # .delay() sends the task to the queue - returns immediately
    task = generate_report.delay(
        user_id=user_id,
        report_type=report_type,
        params={"year": 2024, "format": "pdf"},
    )
    return {"task_id": task.id, "status": "queued"}

@app.get("/api/reports/{task_id}")
async def check_report_status(task_id: str):
    from celery.result import AsyncResult
    result = AsyncResult(task_id)
    if result.state == "PENDING":
        return {"status": "pending"}
    elif result.state == "PROGRESS":
        return {"status": "in_progress", "progress": result.info}
    elif result.state == "SUCCESS":
        return {"status": "complete", "path": result.result}
    elif result.state == "FAILURE":
        return {"status": "failed", "error": str(result.info)}
    # Any other state (for example STARTED or RETRY) is reported as-is
    return {"status": result.state.lower()}
Celery Pipelines (Canvas)
from celery import chain, group, chord
# Chain: task1 → task2 → task3 (sequential, output of one is input of next)
pipeline = chain(
fetch_raw_data.s(source_id=42),
clean_data.s(),
extract_features.s(),
generate_report.s(report_type="summary"),
)
result = pipeline.delay()
# Group: run multiple tasks in parallel
parallel_tasks = group([
analyse_segment.s(segment_id=i)
for i in range(10)
])
results = parallel_tasks.delay()
# Chord: group + callback (run N tasks in parallel, then one aggregator)
report_chord = chord(
group([fetch_data.s(region=r) for r in ["US", "EU", "APAC"]]),
aggregate_and_report.s(), # called with list of results from the group
)
report_chord.delay()
Running Workers
# Start workers - one for each queue
celery -A celery_app worker --queues heavy --concurrency 4 --loglevel info &
celery -A celery_app worker --queues light --concurrency 16 --loglevel info &
# Monitor
celery -A celery_app flower --port 5555
# Open http://localhost:5555 for the Flower dashboard
Section 8: Choosing the Right Tool
Every scale problem has a best-fit tool. Choosing wrong adds complexity without benefit.
Decision Matrix
| Situation | Tool | Reason |
|---|---|---|
| CPU-bound on one machine, data fits in RAM | ProcessPoolExecutor | Simplest true parallelism, stdlib |
| CPU-bound, need shared large array | multiprocessing.SharedMemory | Zero-copy array sharing |
| I/O-bound on one machine | asyncio + httpx | Event-loop concurrency, no process overhead |
| Distributed CPU compute, dynamic task graphs | Ray | Actor model, object store, cluster management |
| Data larger than RAM, NumPy/pandas operations | Dask | Lazy chunked execution, fits the pandas/NumPy mental model |
| Background jobs (email, reports, ML pipelines) | Celery | Persistent queue, retry logic, scheduling |
| ML training (large models) | PyTorch DDP / Ray Train | Specialised distributed training |
| Streaming data processing | Faust / Kafka + asyncio | Event-driven streaming |
| One-off large data transformations | Dask or PySpark | Out-of-core execution |
The Complexity Budget
| Tool | Local Dev Overhead | Production Overhead | Team Expertise Required |
|---|---|---|---|
| ProcessPoolExecutor | Minimal | Minimal (stdlib) | Low |
| multiprocessing | Minimal | Minimal (stdlib) | Low |
| Ray | Medium (ray start) | High (cluster management) | Medium |
| Dask | Low (local cluster) | Medium (distributed cluster) | Medium |
| Celery | High (broker + workers) | High (broker, monitoring) | Medium |
| Spark (PySpark) | Very high | Very high | High |
Do not introduce Celery for background tasks that run in < 1 second and do not need retry logic - asyncio.create_task() is sufficient. Do not introduce Ray for parallelism that ProcessPoolExecutor handles correctly - Ray has a higher operational cost.
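For the sub-second case, a minimal sketch of the `asyncio.create_task()` alternative is shown here. `record_audit_event`, `handle_request` and the in-memory `audit_log` are hypothetical stand-ins for a real handler and its side effect.

```python
import asyncio

background_tasks: set[asyncio.Task] = set()
audit_log: list[str] = []


async def record_audit_event(event: str) -> None:
    """Hypothetical sub-second background job."""
    await asyncio.sleep(0.01)
    audit_log.append(event)


async def handle_request(user_id: int) -> dict:
    # Schedule the work and return without awaiting it.
    task = asyncio.create_task(record_audit_event(f"user {user_id} logged in"))
    # Keep a strong reference: the event loop only holds weak references to tasks.
    background_tasks.add(task)
    task.add_done_callback(background_tasks.discard)
    return {"status": "ok"}


async def main() -> dict:
    response = await handle_request(42)
    await asyncio.gather(*background_tasks)  # demo only: let the job finish before exit
    return response
```

The trade-off is durability: a task created this way lives only in the current process, so it is lost on a crash or restart and gets no automatic retry. That is exactly what a Celery queue adds.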
Interview Questions
Q1: The GIL prevents Python threads from running in parallel for CPU-bound code. Name two scenarios where threading IS beneficial for Python programs and explain why.
Scenario 1 - I/O-bound operations: When a Python thread performs a network read (socket.recv()), file read (f.read()), or any other blocking syscall, CPython releases the GIL before entering the syscall. While that thread is blocked in the kernel, other Python threads can execute. For a service that makes 10 outbound HTTP requests, threading provides genuine parallelism: 10 threads each block in their respective recv() calls simultaneously, all with the GIL released.
Scenario 2 - C extensions that release the GIL: NumPy explicitly releases the GIL for many of its array operations. A thread calling np.dot(A, B) releases the GIL before dispatching to BLAS, which runs natively in C. Other Python threads can run Python code concurrently. For workflows that alternate between NumPy computation and Python orchestration logic, threading can provide genuine speedup - the NumPy work runs in parallel with other threads' Python code.
Q2: What is serialisation overhead in multiprocessing and what strategies exist to minimise it?
When multiprocessing.Pool sends a task to a worker, it serialises the function arguments using pickle, writes them to a pipe, the worker reads the pipe and deserialises them, does the work, serialises the result, and sends it back. This pickle → pipe → unpickle round trip is paid for every task.
For small arguments, this overhead dominates the actual computation time - as demonstrated in the opening benchmark (about 93x slower for 10-byte items).
Strategies to minimise it:
- Increase item size / reduce task count: batch many small items into one large task instead of many small tasks. Instead of 100,000 tasks of 10 bytes each, send 8 tasks of 125KB each.
- multiprocessing.shared_memory: put large arrays in shared memory. Workers receive only a reference (name, shape, dtype - a few dozen bytes total), not the data. Zero-copy on the same machine.
- ray.put(): Ray's object store keeps the serialised object in shared memory on the local machine and sends only a small object reference to workers. Workers on the same node can then read large NumPy arrays from the store without copying them.
- initializer and global state: use the Pool(initializer=init_fn) parameter to load large read-only data (e.g., a lookup table, a model) into each worker process once at startup. Workers then operate on the global rather than receiving the data as arguments.
- Use imap_unordered with chunked generators: avoid loading all input data into memory at once; stream it in chunks.
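The initializer strategy is the one without an example elsewhere in this lesson, so here is a minimal sketch. `init_worker`, `lookup`, `run` and the squared-number table are hypothetical; in practice the initializer would load a model or a lookup file.

```python
import multiprocessing

_LOOKUP: dict[int, int] = {}  # per-process global, filled in by the initializer


def init_worker(table_size: int) -> None:
    """Runs once in each worker at startup and builds the read-only table there."""
    global _LOOKUP
    _LOOKUP = {i: i * i for i in range(table_size)}


def lookup(key: int) -> int:
    """Each task receives only a small key; the table is already in the worker."""
    return _LOOKUP.get(key, -1)


def run(keys: list[int], table_size: int = 1_000) -> list[int]:
    with multiprocessing.Pool(
        processes=2, initializer=init_worker, initargs=(table_size,)
    ) as pool:
        return pool.map(lookup, keys)
```

Only `table_size` crosses the process boundary at startup and only integer keys cross it per task, so the table itself is never pickled. Call `run([...])` from under an `if __name__ == "__main__":` guard.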
Q3: What is Dask's lazy evaluation model and why is it important for large-scale data processing?
Dask does not execute operations immediately when you call .map(), .groupby(), or .mean(). Instead, it builds a directed acyclic graph (DAG) representing the computation - a task graph where each node is an operation and edges represent data dependencies. No actual computation or data reading happens until you call .compute().
This lazy evaluation model is important for several reasons:
1. Optimisation before execution: Dask can analyse the full computation graph and optimise it - fusing adjacent operations to reduce passes over the data, reusing intermediate results, eliminating redundant computations. If you chain two filters (for example bag.filter(...).filter(...)), Dask can fuse them so each partition is processed in a single pass.
2. Out-of-core execution: Dask processes data in chunks (partitions). Each partition fits in memory. Dask schedules partitions one (or a few) at a time, discarding processed partitions from memory before loading the next ones. This allows processing 100GB datasets on a 16GB machine.
3. Distributed execution: the task graph is a portable description of the computation that Dask's distributed scheduler can distribute across workers on multiple machines. The same code runs locally or on a cluster by changing the scheduler.
4. Avoidance of unnecessary intermediate results: if you only need the final aggregated result, Dask does not materialise the full intermediate dataframe - it streams data through the pipeline partition by partition.
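The build-then-execute split can be illustrated without Dask. The sketch below is a toy model of deferred evaluation, not Dask's implementation: `Lazy`, `load` and `total` are hypothetical names, and the `calls` list exists only to show when work actually happens.

```python
from typing import Any, Callable


class Lazy:
    """Toy deferred value: records a function and its inputs, runs nothing yet."""

    def __init__(self, func: Callable[..., Any], *deps: Any) -> None:
        self.func = func
        self.deps = deps

    def compute(self) -> Any:
        # Resolve dependencies first (depth-first walk of the graph), then run this node.
        args = [d.compute() if isinstance(d, Lazy) else d for d in self.deps]
        return self.func(*args)


calls: list[str] = []


def load(n: int) -> list[int]:
    calls.append("load")
    return list(range(n))


def total(xs: list[int]) -> int:
    calls.append("total")
    return sum(xs)


graph = Lazy(total, Lazy(load, 5))  # builds the graph only
calls_before = list(calls)          # still empty: nothing has run
result = graph.compute()            # executes load, then total
```

A real scheduler adds what this toy lacks: it inspects the whole graph before running it, executes independent nodes in parallel, and frees intermediate results as soon as nothing else depends on them.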
Q4: Explain the Celery task retry pattern with exponential backoff. When should you NOT retry a Celery task?
Celery retries are configured via max_retries, default_retry_delay, and the self.retry() call inside the task:
# Assumes `app` is a configured Celery instance, and that `smtp_client` and
# SMTPTemporaryError come from your mail library.
@app.task(bind=True, max_retries=3)
def send_email(self, to: str, subject: str, body: str) -> None:
    try:
        smtp_client.send(to=to, subject=subject, body=body)
    except SMTPTemporaryError as exc:
        # Exponential backoff: 1min, 2min, 4min before giving up
        raise self.retry(exc=exc, countdown=60 * 2 ** self.request.retries)
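self.request.retries is 0 on the first attempt, 1 on the first retry, and 2 on the second, so 2 ** self.request.retries gives 1, 2, 4: exponential backoff. This matters when many tasks fail at once (e.g., a downstream outage): retrying immediately would hit the recovering service with the full load again, whereas growing delays give it time to recover. Backoff alone does not desynchronise tasks that failed at the same moment; adding random jitter to the countdown spreads the retries out and avoids a thundering herd.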
Do NOT retry when:
- The failure is permanent / deterministic: SMTPPermanentError (invalid email address, blocked sender) - retrying will always fail. Log and discard.
- The task already produced a side effect: if the task sent a notification, inserted a database record, or charged a payment card before failing on a subsequent step, retrying would duplicate the side effect. Make tasks idempotent (safe to run multiple times) or track completion state in the result backend.
- The error is the caller's fault: invalid input (validation errors, malformed data) - retrying will always fail with the same error. Log and alert the caller.
- The retry budget is exhausted: Celery will not retry beyond max_retries. On final failure, log the error, alert the team, and optionally move to a dead-letter queue for manual inspection.
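The retry rules above can be sketched as a small decision function, independent of Celery. All names here (TransientError, PermanentError, backoff_delay, next_action) are hypothetical, and the 60-second base and one-hour cap are assumptions chosen to match the example task.

```python
import random
from typing import Optional


class TransientError(Exception):
    """Failure that may succeed later (timeout, temporary SMTP error)."""


class PermanentError(Exception):
    """Failure that will never succeed (invalid address, bad input)."""


def backoff_delay(
    retries: int,
    base: float = 60.0,
    cap: float = 3600.0,
    rng: Optional[random.Random] = None,
) -> float:
    """Exponential backoff in seconds, capped, with optional full jitter."""
    ceiling = min(cap, base * 2 ** retries)
    if rng is None:
        return ceiling
    # Full jitter: pick a random delay up to the ceiling so that tasks
    # which failed together do not all retry at the same instant.
    return rng.uniform(0, ceiling)


def next_action(
    exc: Exception, retries: int, max_retries: int = 3
) -> tuple[str, Optional[float]]:
    """Decide what to do with a failed task: discard, dead-letter, or retry."""
    if isinstance(exc, PermanentError):
        return ("discard", None)      # retrying can never succeed
    if retries >= max_retries:
        return ("dead_letter", None)  # retry budget exhausted
    return ("retry", backoff_delay(retries))
```

Inside a Celery task, the "retry" branch would correspond to raise self.retry(exc=exc, countdown=delay). Side-effect safety is not covered here and still has to come from idempotent task design.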
Q5: You need to process 1 billion rows of financial transaction data to detect fraud patterns. The data is stored as Parquet files totalling 500GB. How would you architect this in Python?
This is a large-scale batch processing problem. The right tool is Dask (on a single cluster of machines) or PySpark. Here is the Dask architecture:
Infrastructure: 8-16 machines with 32GB RAM each, running Dask distributed workers. One Dask scheduler node.
Ingestion:
import dask.dataframe as dd
from dask.distributed import Client

client = Client("tcp://dask-scheduler:8786")
df = dd.read_parquet(
    "s3://data-lake/transactions/year=2024/",  # Parquet partitioned by year
    columns=["txn_id", "user_id", "amount", "merchant_id", "timestamp", "country"],
    engine="pyarrow",
)
# Dask reads metadata only - no data in RAM yet
Feature Engineering (lazy):
df["hour"] = dd.to_datetime(df["timestamp"]).dt.hour
df["is_foreign"] = df["country"] != "US"
# User-level aggregations
user_stats = (
    df.groupby("user_id")
    .agg({"amount": ["sum", "mean", "std", "count"], "is_foreign": "sum"})
    .reset_index()
)
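One way to picture how a grouped aggregation like this can run partition by partition is a two-stage reduce: each partition produces small per-user partials, which are then merged. The following is a plain-Python sketch of that idea under simplified assumptions (only sum, count and mean), not Dask's actual implementation; partial_agg and combine are hypothetical names.

```python
from collections import defaultdict
from typing import Iterable


def partial_agg(rows: Iterable[tuple[str, float]]) -> dict[str, tuple[float, int]]:
    """Stage 1, per partition: reduce rows to (sum, count) per user."""
    acc: dict[str, list] = defaultdict(lambda: [0.0, 0])
    for user_id, amount in rows:
        acc[user_id][0] += amount
        acc[user_id][1] += 1
    return {user: (s, n) for user, (s, n) in acc.items()}


def combine(
    partials: Iterable[dict[str, tuple[float, int]]]
) -> dict[str, dict[str, float]]:
    """Stage 2: merge the small partials and derive the mean at the end."""
    total: dict[str, list] = defaultdict(lambda: [0.0, 0])
    for part in partials:
        for user, (s, n) in part.items():
            total[user][0] += s
            total[user][1] += n
    return {
        user: {"sum": s, "count": n, "mean": s / n}
        for user, (s, n) in total.items()
    }


# Three toy partitions of (user_id, amount) rows.
partitions = [
    [("u1", 10.0), ("u2", 5.0)],
    [("u1", 30.0)],
    [("u2", 15.0), ("u2", 10.0)],
]
stats = combine(partial_agg(p) for p in partitions)
```

The mean is computed from the merged sum and count rather than by averaging per-partition means, which would give the wrong answer when partitions hold different numbers of rows per user.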
Model Scoring: For ML model inference at scale, use Dask's map_partitions to apply a pre-trained fraud model to each partition in parallel:
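import joblib

model = joblib.load("fraud_model.pkl")
model_ref = client.scatter(model, broadcast=True)  # send model to all workers once

# Assumed feature list for illustration; it must match the columns the model was trained on
feature_columns = ["amount", "hour", "is_foreign"]

def score_partition(partition, model):
    features = partition[feature_columns].values
    return partition.assign(fraud_score=model.predict_proba(features)[:, 1])

# Explicit meta: otherwise Dask infers the output schema by running score_partition on dummy data
meta = {**df.dtypes.to_dict(), "fraud_score": "float64"}
scored = df.map_partitions(score_partition, model_ref, meta=meta)
flagged = scored[scored["fraud_score"] > 0.95]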
Output: write flagged transactions back to Parquet for downstream processing:
flagged.to_parquet(
    "s3://data-lake/fraud_flags/2024/",
    write_index=False,
    compression="snappy",
)
Why Dask over alternatives:
- Dask's pandas-like API minimises code changes from a local prototype
- Parquet's columnar layout means Dask reads only the requested columns (column pruning), and predicate pushdown can skip whole row groups that cannot match a filter
- Out-of-core execution handles the 500GB dataset on a cluster without materialising all data simultaneously
- For this use case, PySpark is a valid alternative - it has a more mature ecosystem for very large clusters (100+ nodes) but heavier JVM overhead for moderate scales
